{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e82e44e1",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d67b3443",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import re\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b64635de",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _natural_sort_key(path):\n",
    "    \"\"\"Return a sort key that orders digit runs in the file name numerically.\n",
    "\n",
    "    Plain lexicographic sorting would place ``obs_10.npy`` before ``obs_2.npy``.\n",
    "    \"\"\"\n",
    "    # re.split with a capturing group alternates text / digits, so the odd\n",
    "    # positions of the result are always the digit runs.\n",
    "    tokens = re.split(r\"([0-9]+)\", os.path.basename(path))\n",
    "    return [int(tok) if i % 2 else tok for i, tok in enumerate(tokens)]\n",
    "\n",
    "\n",
    "def _load_sequence(pattern):\n",
    "    \"\"\"Load every ``.npy`` file matching ``pattern`` into one stacked tensor.\n",
    "\n",
    "    Returns the naturally sorted list of matched paths and a tensor whose new\n",
    "    leading axis indexes the files in that order.\n",
    "\n",
    "    Raises FileNotFoundError when nothing matches ``pattern``.\n",
    "    \"\"\"\n",
    "    paths = sorted(glob.glob(pattern), key=_natural_sort_key)\n",
    "    if not paths:\n",
    "        raise FileNotFoundError(f\"No files match {pattern!r}\")\n",
    "    frames = torch.stack([torch.from_numpy(np.load(path)) for path in paths])\n",
    "    return paths, frames\n",
    "\n",
    "\n",
    "# Fall back to the CPU so the notebook still runs on machines without a GPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# glob.glob returns matches in arbitrary order, so both sequences are sorted\n",
    "# the same way to keep obs[:, t] aligned with pose_delta[:, t].\n",
    "# NOTE(review): assumes the number in each file name is the frame index - confirm.\n",
    "obs_paths, obs_frames = _load_sequence(\"demo_data/obs_*.npy\")\n",
    "pose_delta_paths, pose_delta_frames = _load_sequence(\"demo_data/pose_delta_*.npy\")\n",
    "if len(obs_paths) != len(pose_delta_paths):\n",
    "    raise ValueError(\n",
    "        f\"Found {len(obs_paths)} observation files but \"\n",
    "        f\"{len(pose_delta_paths)} pose delta files\"\n",
    "    )\n",
    "\n",
    "# unsqueeze(0) adds the leading batch dimension (batch size 1).\n",
    "obs = obs_frames.unsqueeze(0).to(device)\n",
    "pose_delta = pose_delta_frames.unsqueeze(0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "007ddec2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observation frames: RGB (3 channels), depth (1 channel) and segmentation\n",
    "# (num_sem_categories channels) stacked along the channel axis.\n",
    "# Shape: (batch_size, sequence_length, 3 + 1 + num_sem_categories, frame_height, frame_width)\n",
    "obs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb32c249",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sensor pose change (dy, dx, dtheta) relative to the previous frame.\n",
    "# Shape: (batch_size, sequence_length, 3)\n",
    "pose_delta.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4578e76",
   "metadata": {},
   "source": [
    "## Map Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "487b019c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from home_robot.agent.mapping.dense.semantic.categorical_2d_semantic_map_state import Categorical2DSemanticMapState\n",
    "from home_robot.agent.mapping.dense.semantic.categorical_2d_semantic_map_module import Categorical2DSemanticMapModule\n",
    "\n",
    "# Settings that the state and the module both receive; keeping them in one\n",
    "# place guarantees the two objects are built with identical values.\n",
    "shared_map_kwargs = dict(\n",
    "    num_sem_categories=16,\n",
    "    map_resolution=5,\n",
    "    map_size_cm=4800,\n",
    "    global_downscaling=2,\n",
    ")\n",
    "\n",
    "# The state object stores the global map, the local map and the sensor pose.\n",
    "# Argument meanings are documented on the class definition.\n",
    "semantic_map = Categorical2DSemanticMapState(\n",
    "    device=device,\n",
    "    num_environments=1,\n",
    "    **shared_map_kwargs,\n",
    ")\n",
    "semantic_map.init_map_and_pose()\n",
    "\n",
    "# The module is what updates the local and global maps and poses.\n",
    "# Argument meanings are documented on the class definition.\n",
    "semantic_map_module = Categorical2DSemanticMapModule(\n",
    "    frame_height=480,\n",
    "    frame_width=640,\n",
    "    camera_height=0.88,\n",
    "    hfov=79.0,\n",
    "    vision_range=100,\n",
    "    du_scale=4,\n",
    "    cat_pred_threshold=5.0,\n",
    "    exp_pred_threshold=1.0,\n",
    "    map_pred_threshold=1.0,\n",
    "    **shared_map_kwargs,\n",
    ")\n",
    "semantic_map_module = semantic_map_module.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14f7b71e",
   "metadata": {},
   "source": [
    "## Map Update"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93d51bfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# obs is laid out as (batch, sequence, channels, height, width).\n",
    "sequence_length = obs.shape[1]\n",
    "# The first four channels are RGB + depth; the remainder are semantic categories.\n",
    "num_sem_categories = obs.shape[2] - 4\n",
    "\n",
    "# Per-step boolean flags of shape (1, sequence_length), created on the target device:\n",
    "# dones is False everywhere and update_global is True everywhere.\n",
    "dones = torch.zeros(1, sequence_length, dtype=torch.bool, device=device)\n",
    "update_global = torch.ones(1, sequence_length, dtype=torch.bool, device=device)\n",
    "\n",
    "# Run the whole sequence through the module, starting from the current state.\n",
    "module_outputs = semantic_map_module(\n",
    "    obs,\n",
    "    pose_delta,\n",
    "    dones,\n",
    "    update_global,\n",
    "    semantic_map.local_map,\n",
    "    semantic_map.global_map,\n",
    "    semantic_map.local_pose,\n",
    "    semantic_map.global_pose,\n",
    "    semantic_map.lmb,\n",
    "    semantic_map.origins,\n",
    ")\n",
    "\n",
    "# The updated local and global maps are written straight back into the state;\n",
    "# the remaining outputs are per-step sequences.\n",
    "(\n",
    "    seq_map_features,\n",
    "    semantic_map.local_map,\n",
    "    semantic_map.global_map,\n",
    "    seq_local_pose,\n",
    "    seq_global_pose,\n",
    "    seq_lmb,\n",
    "    seq_origins,\n",
    ") = module_outputs\n",
    "\n",
    "# Only the values from the final step of the sequence are kept in the state.\n",
    "last_step = -1\n",
    "semantic_map.local_pose = seq_local_pose[:, last_step]\n",
    "semantic_map.global_pose = seq_global_pose[:, last_step]\n",
    "semantic_map.lmb = seq_lmb[:, last_step]\n",
    "semantic_map.origins = seq_origins[:, last_step]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16e7df32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global semantic map.\n",
    "# Shape: (batch_size, num_channels, M, M)\n",
    "# NOTE(review): M is presumably the side length of the global map in cells\n",
    "# (map_size_cm / map_resolution) - confirm against the state class.\n",
    "#\n",
    "# num_channels = 4 + num_sem_categories, laid out as:\n",
    "#   0: obstacle map\n",
    "#   1: explored area\n",
    "#   2: current agent location\n",
    "#   3: past agent locations\n",
    "#   4, 5, 6, .., num_sem_categories + 3: semantic categories\n",
    "semantic_map.global_map.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee53d178",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local semantic map visualization\n",
    "\n",
    "import cv2\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "from home_robot.agent.perception.detection.coco_maskrcnn.coco_categories import (\n",
    "    coco_categories_color_palette,\n",
    ")\n",
    "\n",
    "# Flat list of RGB components in [0, 1]; palette index i uses entries 3*i .. 3*i + 2.\n",
    "map_color_palette = [\n",
    "    1.0,\n",
    "    1.0,\n",
    "    1.0,  # 0: empty space\n",
    "    0.6,\n",
    "    0.6,\n",
    "    0.6,  # 1: obstacles\n",
    "    0.95,\n",
    "    0.95,\n",
    "    0.95,  # 2: explored area\n",
    "    0.96,\n",
    "    0.36,\n",
    "    0.26,  # 3: visited area\n",
    "    *coco_categories_color_palette,  # 4 and up: semantic categories\n",
    "]\n",
    "# PIL palettes take integer components in 0..255.\n",
    "map_color_palette = [int(x * 255.0) for x in map_color_palette]\n",
    "\n",
    "obstacle_map = semantic_map.get_obstacle_map(0)\n",
    "explored_map = semantic_map.get_explored_map(0)\n",
    "visited_map = semantic_map.get_visited_map(0)\n",
    "\n",
    "# Shift category ids past the four reserved palette entries. A plain addition\n",
    "# (instead of +=) builds a new array, so the array returned by\n",
    "# get_semantic_map is never modified in place.\n",
    "semantic_categories_map = semantic_map.get_semantic_map(0) + 4\n",
    "\n",
    "# The highest category id is treated as \"no category\".\n",
    "no_category_mask = semantic_categories_map == 4 + num_sem_categories - 1\n",
    "obstacle_mask = np.rint(obstacle_map) == 1\n",
    "explored_mask = np.rint(explored_map) == 1\n",
    "visited_mask = visited_map == 1\n",
    "\n",
    "# Order matters: each assignment overwrites the previous ones where masks overlap,\n",
    "# so obstacles win over explored area and visited cells win over everything.\n",
    "semantic_categories_map[no_category_mask] = 0\n",
    "semantic_categories_map[np.logical_and(no_category_mask, explored_mask)] = 2\n",
    "semantic_categories_map[np.logical_and(no_category_mask, obstacle_mask)] = 1\n",
    "semantic_categories_map[visited_mask] = 3\n",
    "\n",
    "# PIL sizes are (width, height) while NumPy shapes are (rows, cols); passing\n",
    "# the shape unchanged would transpose the layout of a non-square map.\n",
    "map_height, map_width = semantic_categories_map.shape\n",
    "semantic_map_vis = Image.new(\"P\", (map_width, map_height))\n",
    "semantic_map_vis.putpalette(map_color_palette)\n",
    "semantic_map_vis.putdata(semantic_categories_map.flatten().astype(np.uint8))\n",
    "semantic_map_vis = semantic_map_vis.convert(\"RGB\")\n",
    "# Flip vertically for display; np.flipud also converts the image to an array.\n",
    "semantic_map_vis = np.flipud(semantic_map_vis)\n",
    "# Trailing semicolon hides the AxesImage repr in the cell output.\n",
    "plt.imshow(semantic_map_vis);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:home_robot]",
   "language": "python",
   "name": "conda-env-home_robot-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
